package com.capinfo.accumulation.plugin;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Properties;
import java.util.regex.Pattern;

import org.apache.ibatis.executor.parameter.ParameterHandler;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.apache.log4j.Logger;

import com.capinfo.accumulation.parameter.accounting.DataGrid;

/**
 * MyBatis paging plugin for data-grid queries.
 *
 * <p>Intercepts {@code StatementHandler.prepare(Connection)}. For mapped statements whose id ends
 * with {@code ByGrid} and whose parameter object is a {@link Map}, it:
 * <ol>
 *   <li>wraps the original SQL as {@code select *  from (<original>) a where 1=1};</li>
 *   <li>appends one {@code and ...} condition per filter key of the form
 *       {@code COLUMN_OP_TYPE[_FORMAT]} (see {@link #getOperatorSql});</li>
 *   <li>runs a {@code count(*)} of the filtered query on the same connection and stores the
 *       result in the {@link DataGrid} found under the {@code "dataGrid"} key, if present;</li>
 *   <li>appends {@code order by} from the {@code sort}/{@code order} keys and database-specific
 *       paging from the {@code currentPieceNum}/{@code perPieceSize} keys.</li>
 * </ol>
 *
 * <p>Plugin properties: {@code databaseType} ({@code MYSQL} or {@code ORACLE}; any other value
 * disables paging), {@code replacedata} (the placeholder token inside operator templates), and
 * one property per operator code whose value is an SQL template containing that placeholder.
 */
@Intercepts({ @Signature(method = "prepare", type = StatementHandler.class, args = { Connection.class }) })
public class GridInterceptor implements Interceptor {
    /**
     * Logger for this class
     */
    private static final Logger logger = Logger.getLogger(GridInterceptor.class);

    /** Mapped-statement ids this plugin rewrites. */
    private static final Pattern GRID_STATEMENT_ID = Pattern.compile(".+ByGrid$");

    /** Parameter keys that drive paging/sorting and are never turned into filter conditions. */
    private static final List<String> RESERVED_KEYS =
            Arrays.asList("sort", "order", "dataGrid", "currentPieceNum", "perPieceSize");

    /** Accepted shape for column names spliced into SQL (sort column and filter columns). */
    private static final Pattern COLUMN_NAME = Pattern.compile("[A-Za-z_][A-Za-z0-9_$#.]*");

    /** Accepted shape for NUM filter values: one number, or a comma-separated list of numbers. */
    private static final Pattern NUMERIC_LIST = Pattern.compile(
            "\\s*[-+]?(\\d+(\\.\\d*)?|\\.\\d+)(\\s*,\\s*[-+]?(\\d+(\\.\\d*)?|\\.\\d+))*\\s*");

    /** Date format used by DATE filters whose key has no fourth segment. */
    private static final String DEFAULT_DATE_FORMAT = "yyyyMMdd";

    /** Database type (MYSQL / ORACLE); each database pages differently. */
    private String databaseType;

    /** Operator code -> SQL template, plus the "replacedata" placeholder token (plugin properties). */
    private Map<String, String> operator;

    /**
     * Rewrites {@code ...ByGrid} statements to apply filters, sorting and paging, and records the
     * total (filtered, unpaged) row count in the request's {@link DataGrid}. Every other
     * statement is passed through untouched.
     *
     * @param invocation the intercepted {@code prepare(Connection)} call
     * @return the result of {@link Invocation#proceed()}
     * @throws Throwable whatever the count query or the intercepted call throws
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
        MetaObject metaObject = SystemMetaObject.forObject(statementHandler);
        MappedStatement mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement");
        // ID of the SQL statement in the mapper configuration
        String id = mappedStatement.getId();
        if (!GRID_STATEMENT_ID.matcher(id).matches()) {
            return invocation.proceed();
        }

        BoundSql boundSql = statementHandler.getBoundSql();
        if (!(boundSql.getParameterObject() instanceof Map)) {
            // Without a parameter map there is nothing to filter/sort/page by.
            logger.warn("Grid statement " + id + " has no Map parameter; skipping filter/sort/paging");
            return invocation.proceed();
        }
        Map<?, ?> parameter = (Map<?, ?>) boundSql.getParameterObject();

        // Wrapped original SQL plus filter conditions.
        String operSql = getOperatorSql(parameter, getPackageSql(boundSql.getSql()));
        String countSql = getCountSql(operSql);

        String sort = (String) parameter.get("sort");   // sort column
        String order = (String) parameter.get("order"); // asc/desc
        String sortSql = getSortSql(sort, order, operSql);

        Integer page = null;
        Integer rows = null;
        Object pieceNum = parameter.get("currentPieceNum");
        Object pieceSize = parameter.get("perPieceSize");
        if (pieceNum != null && pieceSize != null) {
            page = Integer.parseInt(pieceNum.toString().trim());  // current page (1-based)
            rows = Integer.parseInt(pieceSize.toString().trim()); // rows per page
        }
        String pageSql = getPageSql(page, rows, sortSql);
        logger.info("条件字符串" + pageSql);

        // The count query reuses the statement's own '?' bindings: filters are inlined into the
        // SQL text, so the placeholder set is identical to the original statement's.
        Connection connection = (Connection) invocation.getArgs()[0];
        ParameterHandler parameterHandler = (ParameterHandler) metaObject.getValue("delegate.parameterHandler");
        PreparedStatement countStatement = connection.prepareStatement(countSql);
        try {
            parameterHandler.setParameters(countStatement);
            ResultSet rs = countStatement.executeQuery();
            try {
                if (rs.next()) {
                    DataGrid dataGrid = (DataGrid) parameter.get("dataGrid");
                    if (dataGrid != null) {
                        dataGrid.setTotal(rs.getLong(1));
                    }
                }
            } finally {
                rs.close();
            }
        } finally {
            countStatement.close();
        }

        metaObject.setValue("delegate.boundSql.sql", pageSql);
        return invocation.proceed();
    }

    /**
     * Wraps {@code target} in a proxy that routes the intercepted signature to this plugin.
     *
     * @param target the object MyBatis is about to use
     * @return the (possibly proxied) target
     */
    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    /**
     * Appends {@code order by a.<sort> <order>} to {@code sql}.
     *
     * <p>Both values come from the client, so the column must be a plain identifier and the
     * direction must be {@code asc} or {@code desc} (any case). Otherwise the sort request is
     * logged and ignored rather than spliced into SQL.
     *
     * @param sort the sort column, or {@code null} for no sorting
     * @param order the sort direction, or {@code null} for no sorting
     * @param sql the SQL to sort
     * @return the sorted SQL, or {@code sql} unchanged
     */
    private String getSortSql(String sort, String order, String sql) {
        if (sort == null || order == null) {
            return sql;
        }
        String column = sort.trim();
        String direction = order.trim().toLowerCase(Locale.ROOT);
        boolean validDirection = "asc".equals(direction) || "desc".equals(direction);
        if (!validDirection || !COLUMN_NAME.matcher(column).matches()) {
            logger.warn("Ignoring invalid sort request: sort=" + sort + ", order=" + order);
            return sql;
        }
        return new StringBuilder(sql).append(" order by a.").append(column).append(" ").append(direction).toString();
    }

    /**
     * Wraps the original SQL in a sub-select aliased {@code a} so that filter and sort columns
     * can be referenced uniformly.
     *
     * @param sql the original statement SQL
     * @return {@code select *  from (<sql>) a where 1=1 }
     */
    private String getPackageSql(String sql) {
        return "select *  from (" + sql + ") a where 1=1 ";
    }

    /**
     * Appends one filter condition per filter key in {@code parameter} to {@code sql}.
     *
     * <p>A filter key has the form {@code COLUMN_OP_TYPE} or {@code COLUMN_OP_TYPE_FORMAT}.
     * {@code OP} selects an operator template from the plugin properties; its
     * {@code replacedata} placeholder is replaced by the quote-escaped value. For {@code TYPE}:
     * <ul>
     *   <li>ending in {@code NUM}, the value must be numeric;</li>
     *   <li>ending in {@code DATE}, the condition compares {@code to_char(COLUMN, 'FORMAT')}
     *       (default format {@code yyyyMMdd}) against the value with {@code -}, spaces and
     *       {@code :} removed;</li>
     *   <li>anything else, the value is used as-is.</li>
     * </ul>
     * Reserved paging keys, keys with fewer than three segments, and null/empty values are
     * skipped.
     *
     * @param parameter the statement's parameter map
     * @param sql the wrapped SQL produced by {@link #getPackageSql}
     * @return {@code sql} with the filter conditions appended. If any key has more than four
     *         segments, an error is logged and the unfiltered {@code sql} is returned.
     * @throws IllegalArgumentException if a filter column is not a plain identifier, no operator
     *         template is configured for its operator code, or a NUM value is not numeric
     */
    private String getOperatorSql(Map<?, ?> parameter, String sql) {
        StringBuilder sqlBuffer = new StringBuilder(sql);
        for (Map.Entry<?, ?> entry : parameter.entrySet()) {
            String key = String.valueOf(entry.getKey()).trim();
            if (RESERVED_KEYS.contains(key)) {
                continue;
            }
            String[] parts = key.split("_");
            if (parts.length > 4) {
                // A malformed filter key discards ALL filter conditions (long-standing behaviour).
                logger.error("过滤字段出现多个下划线  ：" + entry.getKey());
                return sql;
            }
            if (parts.length < 3) {
                continue; // ordinary statement parameter, not a filter
            }
            Object rawValue = entry.getValue();
            if (rawValue == null || rawValue.toString().length() == 0) {
                continue;
            }

            String name = parts[0];
            String eq = parts[1];
            String dataType = parts[2];
            String dataFormat = parts.length > 3 ? parts[3] : DEFAULT_DATE_FORMAT;
            if (!COLUMN_NAME.matcher(name).matches()) {
                throw new IllegalArgumentException("Invalid filter column in key: " + key);
            }
            String template = operator.get(eq);
            String placeholder = operator.get("replacedata");
            if (template == null || placeholder == null) {
                throw new IllegalArgumentException("No operator template configured for '" + eq + "' (key " + key + ")");
            }

            String value = rawValue.toString();
            String column = name;
            if (dataType.endsWith("NUM")) {
                if (!NUMERIC_LIST.matcher(value).matches()) {
                    throw new IllegalArgumentException("Non-numeric value for filter " + key);
                }
            } else if (dataType.endsWith("DATE")) {
                // Normalise e.g. "2020-01-31 12:00:00" to "20200131120000".
                value = value.replace("-", "").replace(" ", "").replace(":", "");
                column = "to_char(" + name + ", '" + escapeLiteral(dataFormat) + "' )";
            }
            sqlBuffer.append(" and ").append(column).append(" ")
                    .append(template.replace(placeholder, escapeLiteral(value)));
        }
        return sqlBuffer.toString();
    }

    /**
     * Escapes a value for use inside a single-quoted SQL string literal.
     *
     * <p>Single quotes are doubled. For MySQL, backslashes are doubled too, because MySQL treats
     * backslash as an escape character in string literals by default.
     *
     * @param value the raw value
     * @return the escaped value
     */
    private String escapeLiteral(String value) {
        String escaped = value;
        if ("MYSQL".equals(databaseType)) {
            escaped = escaped.replace("\\", "\\\\");
        }
        return escaped.replace("'", "''");
    }

    /**
     * Applies database-specific paging to {@code sql}. Only MySQL and Oracle are supported; for
     * any other database type, or when either paging value is missing, {@code sql} is returned
     * unpaged.
     *
     * @param page the 1-based page number; values below 1 are treated as 1
     * @param rows the number of rows per page
     * @param sql the (filtered, sorted) SQL
     * @return the paged SQL
     */
    private String getPageSql(Integer page, Integer rows, String sql) {
        if (page == null || rows == null) {
            return sql;
        }
        // A page below 1 would produce a negative MySQL offset (SQL error).
        int firstPage = Math.max(page.intValue(), 1);
        if ("MYSQL".equals(databaseType)) {
            return getMysqlPageSql(firstPage, rows.intValue(), new StringBuilder(sql));
        }
        if ("ORACLE".equals(databaseType)) {
            return getOraclePageSql(firstPage, rows.intValue(), new StringBuilder(sql));
        }
        return sql;
    }

    /**
     * Builds the MySQL paging statement.
     *
     * @param page the 1-based page number
     * @param rows the number of rows per page
     * @param sqlBuffer the original SQL; it is appended to in place
     * @return {@code <sql> limit <offset>,<rows>}
     */
    private String getMysqlPageSql(int page, int rows, StringBuilder sqlBuffer) {
        // MySQL row offsets start at 0; widen to long to avoid int overflow.
        long offset = (long) (page - 1) * rows;
        sqlBuffer.append(" limit ").append(offset).append(",").append(rows);
        return sqlBuffer.toString();
    }

    /**
     * Builds the Oracle paging statement using {@code rownum}.
     *
     * <p>The result has the shape
     * {@code select * from (select u.*, rownum r from (<sql>) u where rownum < 31) where r >= 16}.
     *
     * @param page the 1-based page number
     * @param rows the number of rows per page
     * @param sqlBuffer the original SQL; it is modified in place
     * @return the Oracle paging statement
     */
    private String getOraclePageSql(int page, int rows, StringBuilder sqlBuffer) {
        // rownum starts at 1.
        long offset = (long) (page - 1) * rows + 1;
        sqlBuffer.insert(0, "select u.*, rownum r from (").append(") u where rownum < ").append(offset + rows);
        sqlBuffer.insert(0, "select * from (").append(") where r >= ").append(offset);
        return sqlBuffer.toString();
    }

    /**
     * Builds the total-count query for the wrapped SQL.
     *
     * <p>This relies on {@code sql} starting with the {@link #getPackageSql} wrapper, whose
     * lowercase {@code from} is the first occurrence in the string. The projection is replaced
     * with {@code count(*)}, while the sub-select and filters are kept.
     *
     * @param sql the wrapped and filtered SQL
     * @return the count query
     */
    private String getCountSql(String sql) {
        int index = sql.indexOf("from");
        return "select count(*) " + sql.substring(index);
    }

    /**
     * Reads the plugin configuration.
     *
     * <p>All string properties are copied into the operator table, so the caller's
     * {@link Properties} object is not aliased. {@code databaseType} is read separately.
     *
     * @param properties the plugin properties from the MyBatis configuration
     */
    @Override
    public void setProperties(Properties properties) {
        Map<String, String> operators = new HashMap<String, String>();
        for (String name : properties.stringPropertyNames()) {
            operators.put(name, properties.getProperty(name));
        }
        this.operator = operators;
        this.databaseType = properties.getProperty("databaseType");
    }

}